【他山之石】(Learning from Others) PyTorch Source Code Walkthrough: torch.optim, the Optimizer Interface in Detail
"Stones from other hills may serve to polish the jade of this one." Standing on the shoulders of giants lets us see further and go farther, and on the road of research a favorable wind helps us move faster. To that end, we collect practical code links, datasets, software, and programming tips in this "Learning from Others" column to help you forge ahead. Stay tuned.
Source: https://www.zhihu.com/people/openmmlab
01 Optimizer (torch.optim)
1.0 Basic usage
An optimizer updates the learnable parameters of a model during training. Commonly used optimizers include SGD, RMSprop, and Adam. An optimizer is constructed from the model's learnable parameters together with hyperparameters such as lr and momentum. In the training loop, first call optimizer.zero_grad() to clear the old gradients, then loss.backward() to backpropagate, and finally optimizer.step() to update the model parameters.
import torch
import numpy as np
import warnings
warnings.filterwarnings('ignore')  # ignore warnings

x = torch.linspace(-np.pi, np.pi, 2000)
y = torch.sin(x)

p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(1, 1001):
    y_pred = model(xx)
    loss = loss_fn(y_pred, y)
    if t % 100 == 0:
        print('No.{: 5d}, loss: {:.6f}'.format(t, loss.item()))
    optimizer.zero_grad()  # clear the gradients
    loss.backward()        # backpropagate to compute gradients
    optimizer.step()       # update the parameters
No. 100, loss: 26215.714844
No. 200, loss: 11672.815430
No. 300, loss: 4627.826172
No. 400, loss: 1609.388062
No. 500, loss: 677.805115
No. 600, loss: 473.932159
No. 700, loss: 384.862396
No. 800, loss: 305.365143
No. 900, loss: 229.774719
No. 1000, loss: 161.483841
1.1 Optimizers provided by PyTorch
SGD, ASGD, Adadelta, Adagrad, Adam, AdamW, Adamax, SparseAdam, RMSprop, Rprop, LBFGS
1.2 How the base class Optimizer works
add_param_group(param_group): add a group of learnable parameters to the optimizer
step(closure): perform a single parameter update
zero_grad(): clear the gradients recorded in the previous iteration
state_dict(): return the optimizer's parameters and state as a dict
load_state_dict(state_dict): load a dict of parameters and state
1.2.1 Initializing an Optimizer
class Optimizer(object):
    def __init__(self, params, defaults):
        # dict supplied by the subclass: default hyperparameters shared by all parameter groups
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.param_groups = []

        param_groups = list(params)
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
1.2.2 add_param_group
def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers can be made
    trainable and added to the :class:`Optimizer` as training progresses.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
    """
    assert isinstance(param_group, dict), "param group must be a dict"

    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                        'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    # fill missing hyperparameters of this group from the optimizer-wide defaults
    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " + name)
        else:
            param_group.setdefault(name, default)

    params = param_group['params']
    if len(params) != len(set(params)):
        warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                      "in future, this will cause an error; "
                      "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")

    self.param_groups.append(param_group)
from torch.optim import SGD
from torch import nn

class DummyModel(nn.Module):
    def __init__(self, class_num=10):
        super(DummyModel, self).__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, class_num)

    def forward(self, x):
        x = self.base(x)
        x = self.gap(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

model = DummyModel().cuda()

optimizer = SGD([
    {'params': model.base.parameters()},
    {'params': model.fc.parameters(), 'lr': 1e-3}  # use a different learning rate for the fc parameters
], lr=1e-2, momentum=0.9)
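As the docstring above mentions, parameters that start out frozen can later be handed to an existing optimizer via add_param_group. A minimal sketch (new_head is a hypothetical extra layer unfrozen later in training; its hyperparameters are made up):

new_head = nn.Linear(128, 20).cuda()               # hypothetical extra layer
optimizer.add_param_group({'params': new_head.parameters(), 'lr': 1e-4})
print(len(optimizer.param_groups))                 # 3: base, fc, and the newly added group
# hyperparameters not given here (e.g. momentum) are filled in from the optimizer defaults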
1.2.3 step
The base class Optimizer only defines the interface of the step method, as shown below:
def step(self, closure):
    r"""Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.

    .. note::
        Unless otherwise specified, this function should not modify the
        ``.grad`` field of the parameters.
    """
    raise NotImplementedError
A subclass such as SGD then implements its own step method, as shown below:
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad
            if weight_decay != 0:
                d_p = d_p.add(p, alpha=weight_decay)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            p.add_(d_p, alpha=-group['lr'])

    return loss
The step method accepts an optional closure function. Its main purpose is to support optimization algorithms such as Conjugate Gradient and LBFGS, which need to re-evaluate the model several times per update. A reminder of the closure concept in Python: when an inner function references variables from its enclosing scope (and the outer function typically returns that inner function), the inner function is called a closure. A minimal, PyTorch-free sketch follows.
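As a quick illustration of the closure concept on its own (the function names are made up):

def make_counter():
    count = 0                   # variable from the enclosing scope
    def increment():
        nonlocal count          # the inner function references and updates it
        count += 1
        return count
    return increment            # the returned inner function is a closure

counter = make_counter()
print(counter(), counter())     # 1 2 -- `count` survives between calls

The example below applies the same idea to training: closure() re-evaluates the loss and gradients, and optimizer.step(closure=closure) calls it.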
from torch.nn import CrossEntropyLoss

dummy_model = DummyModel().cuda()

optimizer = SGD(dummy_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
# define the loss
loss_fn = CrossEntropyLoss()
# fake data
batch_size = 2
data = torch.randn(64, 3, 64, 128).cuda()  # fake inputs, shape 64 x 3 x 64 x 128
data_label = torch.randint(0, 10, size=(64,), dtype=torch.long).cuda()  # fake labels

for batch_index in range(10):
    batch_data = data[batch_index * batch_size: batch_index * batch_size + batch_size]
    batch_label = data_label[batch_index * batch_size: batch_index * batch_size + batch_size]

    def closure():
        optimizer.zero_grad()                 # clear the gradients
        output = dummy_model(batch_data)      # forward
        loss = loss_fn(output, batch_label)   # compute the loss
        loss.backward()                       # backward
        print('No.{: 2d} loss: {:.6f}'.format(batch_index, loss.item()))
        return loss

    optimizer.step(closure=closure)  # update the parameters
No. 0 loss: 2.279336
No. 1 loss: 2.278228
No. 2 loss: 2.291000
No. 3 loss: 2.245984
No. 4 loss: 2.236940
No. 5 loss: 2.104764
No. 6 loss: 2.227481
No. 7 loss: 2.108526
No. 8 loss: 2.254484
No. 9 loss: 2.536439
1.2.4 zero_grad
zero_grad clears the gradients recorded in the previous iteration before the next backward pass computes new ones. When set_to_none is True, the gradients are set to None instead of being zeroed, which reduces memory usage; however, enabling it is usually not recommended, because PyTorch handles None gradients and zero gradients differently (see the notes in the docstring below).
def zero_grad(self, set_to_none: bool = False):
    r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

    Arguments:
        set_to_none (bool): instead of setting to zero, set the grads to None.
            This will in general have lower memory footprint, and can modestly improve performance.
            However, it changes certain behaviors. For example:
            1. When the user tries to access a gradient and perform manual ops on it,
               a None attribute or a Tensor full of 0s will behave differently.
            2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
               are guaranteed to be None for params that did not receive a gradient.
            3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
               (in one case it does the step with a gradient of 0 and in the other it skips
               the step altogether).
    """
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                        p.grad.requires_grad_(False)
                    p.grad.zero_()
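A quick check of the difference described above, using a throwaway linear layer (a minimal sketch; the layer and shapes are arbitrary):

lin = torch.nn.Linear(4, 2)
torch.nn.functional.mse_loss(lin(torch.randn(8, 4)), torch.randn(8, 2)).backward()

opt = torch.optim.SGD(lin.parameters(), lr=0.1)
opt.zero_grad()                    # default: gradients become tensors of zeros
print(lin.weight.grad)             # a tensor filled with 0s
opt.zero_grad(set_to_none=True)    # gradients are dropped entirely
print(lin.weight.grad)             # None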
1.2.5 state_dict() and load_state_dict()
state_dict(): returns the parameters managed by the optimizer and their state as a dict. load_state_dict(state_dict): loads such a dict and restores the parameters and their state. Together, the two methods make it possible to resume training after an interruption; a short checkpoint sketch follows the source listing below.
def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict`.

    It contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.
    * param_groups - a dict containing all parameter groups
    """
    # Save order indices instead of Tensors
    param_mappings = {}
    start_index = 0

    def pack_group(group):
        nonlocal start_index
        packed = {k: v for k, v in group.items() if k != 'params'}
        param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                               if id(p) not in param_mappings})
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])
        return packed

    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use order indices as keys
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}

    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
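A minimal checkpoint/resume sketch built on these two methods (the file name checkpoint.pth is arbitrary):

# Save the optimizer state together with the model, then restore both to resume training.
torch.save({'model': dummy_model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth')

ckpt = torch.load('checkpoint.pth')
dummy_model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])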
1.3 A quick look at common optimizers
1.3.1 SGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False)
1.3.2 Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)
1.3.3 RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
1.3.4 Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False)
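All four constructors take the parameter iterable plus the hyperparameters listed in their signatures; a minimal sketch reusing the DummyModel instance from section 1.2 (the hyperparameter values are arbitrary):

# Sketch: constructing the four optimizers above with arbitrary hyperparameters.
opt_sgd     = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
opt_adagrad = torch.optim.Adagrad(model.parameters(), lr=1e-2, lr_decay=1e-4)
opt_rmsprop = torch.optim.RMSprop(model.parameters(), lr=1e-2, alpha=0.99, momentum=0.9)
opt_adam    = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), amsgrad=True)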
02 Learning rate schedulers (torch.optim.lr_scheduler)
StepLR: decays the learning rate at a fixed interval
MultiStepLR: decays the learning rate at specified milestones
ExponentialLR: exponential decay
ReduceLROnPlateau: adapts the learning rate based on a monitored metric
CyclicLR: cyclical schedule
OneCycleLR: single-cycle schedule
CosineAnnealingLR: cosine annealing
CosineAnnealingWarmRestarts: cosine annealing with warm restarts
LambdaLR: schedule defined by a user function
MultiplicativeLR: multiplies the learning rate by a user-defined factor
Predefined schedules: StepLR, MultiStepLR, ExponentialLR, CyclicLR, OneCycleLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
Adaptive schedule: ReduceLROnPlateau
User-defined schedules: LambdaLR, MultiplicativeLR
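Whatever the strategy, usage follows the same pattern: build the scheduler around an optimizer and call scheduler.step() once per epoch after optimizer.step(). A self-contained sketch on toy data (every name below is a placeholder, not part of the library):

net = torch.nn.Linear(10, 2)
criterion = torch.nn.MSELoss()
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)

inputs, targets = torch.randn(32, 10), torch.randn(32, 2)
for epoch in range(100):
    opt.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    opt.step()      # update the parameters first ...
    sched.step()    # ... then let the scheduler adjust the learning rate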
2.1 The base class: _LRScheduler
step(epoch=None): implemented in the base class and shared by subclasses
get_lr(): must be implemented by each subclass
get_last_lr(): shared by subclasses
print_lr(is_verbose, group, lr, epoch=None): prints learning rate adjustment information
state_dict(): may be overridden by subclasses
load_state_dict(state_dict): may be overridden by subclasses
2.1.1 Initialization
class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()
2.1.2 step
def step(self, epoch=None):
    # Raise a warning if old pattern is detected
    # https://github.com/pytorch/pytorch/issues/20124
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("...")  # warning text omitted here

        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif self.optimizer._step_count < 1:
            warnings.warn("...")  # warning text omitted here
    self._step_count += 1

    class _enable_get_lr_call:

        def __init__(self, o):
            self.o = o

        def __enter__(self):
            self.o._get_lr_called_within_step = True
            return self

        def __exit__(self, type, value, traceback):
            self.o._get_lr_called_within_step = False

    with _enable_get_lr_call(self):
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                values = self._get_closed_form_lr()
            else:
                values = self.get_lr()

    for i, data in enumerate(zip(self.optimizer.param_groups, values)):
        param_group, lr = data
        param_group['lr'] = lr
        self.print_lr(self.verbose, i, lr, epoch)

    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
2.1.3 get_last_lr, get_lr, and print_lr
get_last_lr() is straightforward: it returns the learning rates of all parameter groups as recorded after the most recent step() call.
get_lr() is an abstract method that defines the interface for computing the new learning rates; each subclass provides its own implementation, and the return value has the form [lr1, lr2, ...], one entry per parameter group.
print_lr(is_verbose, group, lr, epoch=None) prints information about learning rate adjustments.
def get_last_lr(self):
    """ Return last computed learning rate by current scheduler.
    """
    return self._last_lr

def get_lr(self):
    # Compute learning rate using chainable form of the scheduler
    raise NotImplementedError

def print_lr(self, is_verbose, group, lr, epoch=None):
    """Display the current learning rate.
    """
    if is_verbose:
        if epoch is None:
            print('Adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(group, lr))
        else:
            print('Epoch {:5d}: adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(epoch, group, lr))
2.1.4 state_dict and load_state_dict
state_dict(): returns every attribute of the scheduler instance except self.optimizer as a dict. load_state_dict(state_dict): reloads a previously saved state dict. A short checkpoint sketch follows the source listing below.
def state_dict(self):
    """Returns the state of the scheduler as a :class:`dict`.

    It contains an entry for every variable in self.__dict__ which
    is not the optimizer.
    """
    return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

def load_state_dict(self, state_dict):
    """Loads the schedulers state.

    Arguments:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)
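In practice the scheduler state is checkpointed together with the model and optimizer so training can resume mid-schedule; a minimal sketch (model, optimizer, and scheduler stand for whatever objects are in use, and the file name is arbitrary):

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'ckpt.pth')

ckpt = torch.load('ckpt.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])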
2.2 Examples of learning rate schedules
2.2.1 StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False)
class StepLR(_LRScheduler):

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]
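To make the closed form concrete: with base_lr = 0.1, step_size = 30 and gamma = 0.1, the learning rate is 0.1 for epochs 0-29, 0.01 for epochs 30-59 and 0.001 for epochs 60-89, which is exactly the staircase the visualization below produces. A one-line check:

# Worked check of the closed form above (values rounded for readability).
base_lr, step_size, gamma = 0.1, 30, 0.1
print([round(base_lr * gamma ** (e // step_size), 6) for e in (0, 29, 30, 59, 60, 89)])
# [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]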
## Visualizing the learning rate
from torch.optim import lr_scheduler
from matplotlib import pyplot as plt
%matplotlib inline

def create_optimizer():
    return SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

def plot_lr(scheduler, title='', labels=['base'], nrof_epoch=100):
    lr_li = [[] for _ in range(len(labels))]
    epoch_li = list(range(nrof_epoch))
    for epoch in epoch_li:
        scheduler.step()              # compute and apply the learning rate for the current epoch
        lr = scheduler.get_last_lr()  # read back the learning rate for the current epoch
        for i in range(len(labels)):
            lr_li[i].append(lr[i])
    for lr, label in zip(lr_li, labels):
        plt.plot(epoch_li, lr, label=label)
    plt.grid()
    plt.xlabel('epoch')
    plt.ylabel('lr')
    plt.title(title)
    plt.legend()
    plt.show()

## Visualize StepLR
optimizer = create_optimizer()
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
plot_lr(scheduler, title='StepLR')
2.2.2 MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
optimizer = create_optimizer()
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 35, 45], gamma=0.5)
plot_lr(scheduler, title='MultiStepLR')
2.2.3 MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
optimizer = SGD([
    {'params': model.base.parameters()},
    {'params': model.fc.parameters(), 'lr': 0.05}  # use a different learning rate for the fc parameters
], lr=0.1, momentum=0.9)

lambda_base = lambda epoch: 0.5 if epoch % 10 == 0 else 1
lambda_fc = lambda epoch: 0.8 if epoch % 10 == 0 else 1
scheduler = lr_scheduler.MultiplicativeLR(optimizer, [lambda_base, lambda_fc])
plot_lr(scheduler, title='MultiplicativeLR', labels=['base', 'fc'])
2.2.4 LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
# LambdaLR usage example
def lambda_foo(epoch):
    if epoch < 10:
        return (epoch + 1) * 1e-3
    elif epoch < 40:
        return 1e-2
    else:
        return 1e-3

optimizer = create_optimizer()
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_foo)
plot_lr(scheduler, title='LambdaLR')
2.2.5 ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)
optimizer = create_optimizer()
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
plot_lr(scheduler, title='ExponentialLR')
2.2.6 CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False)
optimizer = create_optimizer()
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 10, 1e-5)
plot_lr(scheduler, title='CosineAnnealingLR')
2.2.7 CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False)
T_0: number of epochs in the first cycle. T_mult: factor by which the cycle length is multiplied after each restart. eta_min: lower bound on the learning rate. With T_0=10 and T_mult=2 as below, the cycle lengths are 10, 20, 40, ..., so restarts occur at epochs 10, 30, 70, and so on.
optimizer = create_optimizer()
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2)
plot_lr(scheduler, title='CosineAnnealingWarmRestarts')
2.2.8 CyclicLR(optimizer, base_lr, max_lr, step_size_up=2000, step_size_down=None, mode='triangular', ...)
base_lr: base learning rate, which is also the minimum. max_lr: upper bound on the learning rate. step_size_up: number of epochs in the increasing phase of a cycle. step_size_down: number of epochs in the decreasing phase of a cycle.
optimizer = create_optimizer()
scheduler = lr_scheduler.CyclicLR(optimizer, 0.01, 0.1, step_size_up=25, step_size_down=10)
plot_lr(scheduler, title='CyclicLR')
2.2.9 OneCycleLR(optimizer, max_lr, total_steps=None, epochs=None, steps_per_epoch=None, pct_start=0.3,...)
max_lr: float or list, the upper bound on the learning rate. total_steps: int, the total number of steps in the cycle.
optimizer = create_optimizer()
scheduler = lr_scheduler.OneCycleLR(optimizer, 0.1, total_steps=100)
plot_lr(scheduler, title='OneCycleLR')
2.2.10 ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', ...)
mode: 'min' or 'max', how the monitored quantity is judged. factor: multiplicative factor by which the learning rate is reduced, similar to gamma. patience: number of epochs without improvement to tolerate before reducing the learning rate. The snippet below shows the canonical pattern (train(...) and validate(...) are placeholders); a runnable variant follows it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
for epoch in range(100):
    train(...)
    val_loss = validate(...)
    scheduler.step(val_loss)
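Since train(...) and validate(...) above are placeholders, here is a self-contained sketch that feeds a synthetic, stagnating validation loss to the scheduler, just to show when the learning rate actually drops (all numbers are made up):

net = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
sched = lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=3)

fake_val_losses = [1.0, 0.9, 0.8] + [0.8] * 10   # stops improving after epoch 2
for epoch, val_loss in enumerate(fake_val_losses):
    sched.step(val_loss)
    print(epoch, opt.param_groups[0]['lr'])      # lr drops to 0.01 once no improvement is seen for `patience` epochs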
03 Stochastic Weight Averaging (torch.optim.swa_utils)
AveragedModel: the weight-averaged model used by the SWA algorithm
SWALR: a learning rate schedule designed to be used together with AveragedModel
update_bn: recomputes the BatchNorm statistics of a model
3.0 A brief look at SWA
SWA (Stochastic Weight Averaging) keeps a running average of model weights sampled at different points of training, typically while SGD keeps exploring with a relatively high or cyclical learning rate; the averaged weights usually generalize better than the final iterate, at almost no extra training cost.
3.1 AveragedModel
This class implements the weight-averaged model of the SWA algorithm. It is constructed from a model and a parameter-averaging function avg_fn; the constructor deep-copies the model and registers a counter buffer n_averaged. Each subsequent call to update_parameters(model) updates the SWA parameters by applying avg_fn to the stored average and the current model parameters.
class AveragedModel(Module):
    def __init__(self, model, device=None, avg_fn=None):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        # a default avg_fn is provided; you can pass in your own
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1
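The default avg_fn above is simply an incremental running mean: after the (n+1)-th update the average becomes avg + (x - avg) / (n + 1), so the SWA weights equal the arithmetic mean of every parameter snapshot passed to update_parameters. A quick numeric check of that identity (toy tensors, not a real model):

# The update rule of the default avg_fn reproduces the arithmetic mean of the snapshots.
snapshots = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([6.0])]
avg = snapshots[0].clone()
for n, x in enumerate(snapshots[1:], start=1):
    avg = avg + (x - avg) / (n + 1)
print(avg, sum(snapshots) / len(snapshots))   # both print tensor([3.])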
3.2 update_bn
def update_bn(loader, model, device=None):
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta.keys():
        module.momentum = None
        module.num_batches_tracked *= 0

    # recompute the BN running mean and variance over the whole loader
    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)
loader, model = ...
torch.optim.swa_utils.update_bn(loader, model)
3.3 SWALR
Example:
>>> loader, optimizer, model = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> lr_lambda = lambda epoch: 0.9
>>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
>>>     lr_lambda=lr_lambda)
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
>>>     anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
>>> swa_start = 160
>>> for i in range(300):
>>>     for input, target in loader:
>>>         optimizer.zero_grad()
>>>         loss_fn(model(input), target).backward()
>>>         optimizer.step()
>>>     if i > swa_start:
>>>         swa_scheduler.step()
>>>     else:
>>>         scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)